# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import queue
import warnings
from typing import List, Optional

import pynvml
import torch
import torch.multiprocessing as mp
from torch.nn.utils.rnn import pad_sequence
from transformers import (AutoTokenizer, PreTrainedModel,
                          PreTrainedTokenizerBase)

from lmdeploy.utils import get_logger

from .model import accel_model, init_model


def safe_numel(free_mem, model_size, max_intermediate):
    """Number of elements that fit in memory without out-of-memory.

    Args:
        free_mem (float): Free device memory in bytes.
        model_size (float): Memory taken by the model weights in bytes.
        max_intermediate (float): Memory cost of a single element in bytes.

    Returns:
        int: ``(free_mem - model_size) // max_intermediate`` as an integer,
        clamped at 0 when the model alone does not fit in ``free_mem``.
    """
    # Floor-divide first, then convert: ``int(x) // 2e6`` would still yield
    # a float because the divisor is a float.
    return max(0, int((free_mem - model_size) // max_intermediate))


def avail_gpus(percentage=0.96):
    """Detect available gpus.

    Args:
        percentage (float): The minimum fraction (0-1) of free memory over
            total memory for a GPU to be considered as available.

    Returns:
        tuple[list[int], float]: A list of gpu ids and the average free
        memory (in bytes, as reported by NVML) on a single available gpu.

    Raises:
        RuntimeError: If no GPU satisfies the free-memory threshold.
    """
    gpus = []
    mems = []
    pynvml.nvmlInit()
    try:
        # NOTE(review): NVML enumerates physical devices, while
        # torch.cuda.device_count() honours CUDA_VISIBLE_DEVICES; with a
        # restricted/reordered device list index ``i`` may refer to
        # different GPUs in the two libraries — confirm for your setup.
        for i in range(torch.cuda.device_count()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            free, total = int(mem_info.free), int(mem_info.total)

            if free / total > percentage:
                gpus.append(i)
                mems.append(free)
    finally:
        # Always balance nvmlInit, even if one of the queries raises.
        pynvml.nvmlShutdown()

    if not gpus:
        raise RuntimeError('No GPU available.')

    return gpus, sum(mems) / len(mems)


@torch.no_grad()
def decode_single(model: PreTrainedModel,
                  input_ids: torch.Tensor,
                  attention_mask: torch.Tensor = None,
                  return_logits=True):
    """Decode a single batch.

    Args:
        model (PreTrainedModel): Pretrained causal LM.
        input_ids (torch.Tensor): A batch of input ids, [bs, seq_len].
        attention_mask (torch.Tensor, optional): A batch of attention masks
            with the same shape as ``input_ids``. Outputs at masked
            positions are zeroed.
        return_logits (bool): If True, return the raw logits. Otherwise
            return the probability assigned to each actual next token.

    Returns:
        torch.Tensor (on CPU):
            - ``return_logits=True``: logits of shape
              [bs, seq_len, vocab_size];
            - ``return_logits=False``: probabilities of shape
              [bs, seq_len - 1, 1].

    Note:
        With ``return_logits=False`` this function assumes
        input_ids[i] = [bos, x1, x2, ..., xn]
        and returns prob = [p(x1|bos), p(x2|bos,x1), ..., p(xn|bos..xn-1)],
        so prob is shorter than input_ids by 1.
    """

    # Call Causal LM forward
    outputs = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=False,
                    output_attentions=False,
                    use_cache=False,
                    return_dict=True)
    # [bs, seq_len, vocab_size]; dtype follows the model (presumably fp32
    # or fp16 — depends on how the model was loaded)
    logits = outputs.logits

    if not return_logits:
        # inplace softmax to get probs without a second vocab-sized buffer
        torch.softmax(logits, dim=-1, out=logits)

        # Shift to fetch probabilities: position j predicts token j + 1
        shift_labels = input_ids[..., 1:].contiguous()
        shift_probs = logits[..., :-1, :].contiguous()
        logits = torch.gather(shift_probs, -1, shift_labels.unsqueeze(-1))

        if attention_mask is not None:
            # The output is one step shorter; the prob at position j is
            # valid iff token j + 1 is a real (unpadded) token.
            attention_mask = attention_mask[..., 1:]

    if attention_mask is not None:
        logits *= attention_mask[..., None]

    return logits.cpu()


def worker_fn(model_path: str,
              inq: mp.Queue,
              outq: mp.Queue,
              accel: Optional[str] = None,
              gpu_id=0):
    """Per-GPU decoding loop, meant to run in a spawned subprocess.

    Loads the model onto ``gpu_id``, then repeatedly takes
    ``(idx, payload)`` jobs from ``inq``, decodes them with
    :func:`decode_single` and puts ``(idx, probs)`` on ``outq``.
    A job whose ``idx`` is ``None`` terminates the loop.

    Args:
        model_path (str): Path handed to ``init_model``.
        inq (mp.Queue): Job queue; ``payload`` is
            ``(input_ids, input_lens, *extra)`` where ``extra`` is passed
            positionally to :func:`decode_single` after the mask.
        outq (mp.Queue): Result queue receiving ``(idx, probs)``.
        accel (str, optional): Acceleration method for ``accel_model``.
        gpu_id (int): CUDA device index this worker runs on.
    """
    model, _ = init_model(model_path)
    model = model.eval()
    model = accel_model(model, accel, gpu_id=gpu_id)

    while True:
        try:
            idx, payload = inq.get(timeout=1)
        except queue.Empty:
            # Nothing queued yet; keep polling.
            continue

        if idx is None:
            print(f'Worker {gpu_id} received exit signal.')
            break

        input_ids, input_lens, *extra = payload
        input_ids = input_ids.cuda(gpu_id)

        max_len = max(input_lens)
        assert max_len == input_ids.size(-1), \
            f'input_ids.shape = {input_ids.shape}, max_len = {max_len}'

        # Position j of row i is a real token iff j < input_lens[i].
        lens = torch.tensor(input_lens, device=gpu_id)
        positions = torch.arange(max_len, device=gpu_id)
        attention_mask = positions.unsqueeze(0) < lens.unsqueeze(1)

        assert attention_mask.shape == input_ids.shape, \
            f'attention_mask.shape = {attention_mask.shape}'

        try:
            probs = decode_single(model, input_ids, attention_mask, *extra)
        except torch.cuda.OutOfMemoryError:
            warnings.warn(
                f'OOM on GPU {gpu_id}, discard prompts at indics {idx}.')
            # One empty row per prompt so the collector still receives an
            # entry for every index in ``idx``.
            probs = torch.empty((input_ids.size(0), 0),
                                dtype=torch.float32,
                                device='cpu')

        outq.put((idx, probs))

    print(f'Exiting worker {gpu_id} ...')
    inq.close()
    outq.close()
    print(f'Worker {gpu_id} finished.')


class Engine:
    """Multi-GPU decoding engine.

    Args:
        model_path (str): Path to the pretrained model.
        tokenizer_path (str, optional): Path to the pretrained tokenizer.
            Defaults to None.
            Either tokenizer_path or tokenizer should be provided.
        tokenizer (PreTrainedTokenizerBase, optional): Pre-configured tokenizer.
            Defaults to None.
            Either tokenizer_path or tokenizer should be provided.
        accel (str, optional): Acceleration method.
            Defaults to None. 'deepspeed' is not tested.
        gpu_mem_percentage (float, optional): GPU with memory larger than this value
            are considered available and be used as decode device.
            Defaults to 0.96.
        model_size_byte (float, optional): (Approximate) model size in bytes.
            Defaults to 14e9 (7B model in FP16).
        bytes_per_token (float, optional): (Approximate) memory cost per token in bytes.
            Defaults to 2e6 (2MB).
            ``bytes_per_token`` and ``model_size_byte`` are used to compute
            the maximum batch size for given seq_length
    """  # noqa: E501

    def __init__(self,
                 model_path: str,
                 tokenizer_path: Optional[str] = None,
                 tokenizer: Optional[PreTrainedTokenizerBase] = None,
                 accel: Optional[str] = None,
                 gpu_mem_percentage: float = 0.96,
                 model_size_byte=14e9,
                 bytes_per_token=2e6):

        gpu_ids, mem = avail_gpus(gpu_mem_percentage)
        print(f'Available GPUs are: {gpu_ids}, ', end='')
        print(f'with {mem/2**30:.2f} GiB free.')

        # Resolve the tokenizer before spawning workers, so a failure here
        # does not leave subprocesses running with nobody to stop them.
        if tokenizer is None:

            if tokenizer_path is None:
                tokenizer_path = model_path

            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        ctx = mp.get_context('spawn')
        self.gpu_ids = gpu_ids
        self.inq = ctx.Queue()
        self.outq = ctx.Queue()
        self.tokenizer = tokenizer
        self.safe_numel = safe_numel(mem, model_size_byte, bytes_per_token)

        # Register each process as soon as it starts so __del__ can send it
        # the exit signal even if a later start() fails.
        self.ps = []
        for gpu_id in gpu_ids:
            p = ctx.Process(target=worker_fn,
                            args=(model_path, self.inq, self.outq, accel,
                                  gpu_id))
            p.start()
            self.ps.append(p)

    def clear_queue(self):
        """Drain any leftover items from both queues."""
        for q in self.inq, self.outq:
            while not q.empty():
                q.get()

    def decode(self,
               token_ids: List[List[int]],
               sort=True,
               max_bs: int = 1024,
               pad=True,
               pad_token_id=2,
               return_logits=True):
        """Inference the model to compute probabilities.

        Args:
            token_ids (List[List[int]]): List of list of token ids.
            sort (bool, optional): Internally sort the prompts by length to achieve better efficiency.
                Defaults to True.
                Note: orders of returned probabilities are always the same as the input.
            max_bs (int, optional): Maximum batch size. Must be >= 1.
                Defaults to 1024.
            pad (bool, optional): Pad the prompts in every mini batch to the same length.
                Defaults to True. Set to False to save memory.
            pad_token_id (int, optional): Id used to pad prompts within a mini batch.
                Padded positions are masked out. Defaults to 2.
            return_logits (bool, optional): Return logits instead of probabilities.
                Defaults to True.

        Returns:
            numpy.ndarray: If pad is True, outputs zero-padded to a common length:
                logits of shape [bsz, seqlen, vocab_size] when return_logits is True,
                otherwise probabilities of shape [bsz, seqlen - 1, 1].
            List[numpy.ndarray]: List of per-prompt outputs without padding, if pad is False.
            Prompts a worker discarded (e.g. on OOM) come back as empty rows.

        Raises:
            ValueError: If max_bs < 1.

        Note:
            With return_logits=False this function will accept input token_ids = [x0(=bos), x1, x2, ..., xn]
            and compute prob = [p(x1|x0), p(x2|x0,x1), ..., p(xn|x0..xn-1)]
            So prob is shorter than input_ids by 1.
        """  # noqa: E501
        if max_bs < 1:
            raise ValueError(f'max_bs must be >= 1, got {max_bs}.')

        self.clear_queue()

        total = len(token_ids)
        if total == 0:
            # pad_sequence rejects an empty list, so short-circuit here.
            return torch.empty(0).numpy() if pad else []

        # sort to achieve better efficiency
        if sort:
            indexed = sorted(enumerate(token_ids),
                             key=lambda i_and_x: len(i_and_x[1]))
        else:
            indexed = list(enumerate(token_ids))

        left = 0
        bs = max_bs

        while left < total:

            if not sort:
                bs = max_bs

            right = min(left + bs, total)

            # batch of token id lists together with their original positions
            idx, sub_p = zip(*indexed[left:right])

            input_ids = [torch.tensor(p) for p in sub_p]
            input_ids = pad_sequence(input_ids,
                                     batch_first=True,
                                     padding_value=pad_token_id)
            input_lens = [len(p) for p in sub_p]

            # The last batch can be smaller than bs; start shrinking from the
            # real batch size, otherwise the first reductions are no-ops.
            bs = min(bs, len(idx))

            # Dynamic batch size based on safe memory
            while input_ids.numel() > self.safe_numel and bs > 1:
                bs = max(1, round(bs / 1.5))
                print(f'\nReduce bs to {bs} when seq len reaches '
                      f'{input_ids.shape[-1]}')
                idx = idx[:bs]
                input_lens = input_lens[:bs]
                input_ids = input_ids[:bs, :max(input_lens)]

            # Send to worker; return_logits is forwarded to decode_single.
            self.inq.put((idx, (input_ids, input_lens, return_logits)))

            # Advance by what was actually sent, which may be fewer than
            # right - left once the batch has been shrunk.
            left += len(idx)

            print(
                f'Distributing prompts {left}/{total},'
                f' {left/total:.0%}',
                end='\r')

        print()

        # Collect outputs from workers
        all_probs = [None] * total
        count = 0

        while count < total:
            idx, probs = self.outq.get()
            for i, p in zip(idx, probs):
                assert all_probs[i] is None
                all_probs[i] = p

            count += len(idx)
            print(
                f'Decoding and collecting outputs '
                f'{count}/{total}, '
                f'{count/total:.0%}',
                end='\r')

        print()

        if pad:
            # NOTE(review): an OOM placeholder row from the worker has shape
            # [0], which may not share trailing dims with regular rows, so
            # pad_sequence could fail in that case — confirm.
            all_probs = pad_sequence(all_probs, batch_first=True)
            all_probs = all_probs.cpu().numpy()
        else:
            all_probs = [p.cpu().numpy() for p in all_probs]

        return all_probs

    def __del__(self):
        print('Exiting engine ...')
        # __init__ may have failed before any worker was registered.
        ps = getattr(self, 'ps', [])
        for _ in ps:
            self.inq.put((None, None))
        for p in ps:
            p.join(timeout=1)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path',
                        default='llama2/huggingface/llama-2-7b',
                        help='Path to HugigngFace model and tokenizer.')
    parser.add_argument(
        '--test_path',
        default='',
        help='Path to text file, with each line containing a prompt.')
    parser.add_argument(
        '-p',
        '--prompts',
        nargs='*',
        default=[
            'I believe the meaning of life is to find your gift.',
            'Simply put, the theory of relativity states that',
            'Building a website can be done in 10 simple steps:'
        ],
        help="Prompt in command line, please quote \"\" every sentences, "
        'surpassed by --test_path')
    # type=int: without it a CLI value arrives as str and the length filter
    # below (``len(i) >= args.min_len``) raises TypeError.
    parser.add_argument('--min_len',
                        type=int,
                        default=1,
                        help='Minimum length of prompts')
    parser.add_argument('--save-to',
                        default='decode.out',
                        help='Save results to this file.')
    args = parser.parse_args()

    model_path = args.model_path
    test_path = args.test_path
    prompts = args.prompts

    logger = get_logger(__name__)

    # Use test file preferentially
    if test_path:
        with open(test_path, 'r', encoding='utf-8') as f:
            prompts = f.readlines()

    prompts = [p.strip() for p in prompts]

    # Output infos
    print(f'Model path: {model_path}')

    def _format(ts, start, end):
        """Render ts[start:end] as 'index<TAB>item' lines.

        Negative ``start`` and non-positive ``end`` count from the end, with
        ``end == 0`` meaning "up to the last item".
        """
        if start < 0:
            start += len(ts)
        if end <= 0:
            end += len(ts)
        return '\n'.join(
            (f'{i}\t{t}' for i, t in zip(range(start, end), ts[start:end])))

    if len(prompts) > 10:
        print('Prompts:\n' + _format(prompts, 0, 5) + '\n......\n' +
              _format(prompts, -5, 0))
    else:
        print('Prompts:\n' + _format(prompts, 0, 0))

    # Init Engine in backend
    engine = Engine(model_path)

    # Tokenize: reuse the engine's tokenizer (loaded from the same
    # model_path) instead of loading it a second time.
    tokenizer = engine.tokenizer
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'right'

    input_ids = tokenizer(prompts, padding=False)
    input_ids: List[List[int]] = input_ids.input_ids

    # Filter out too short prompts
    input_ids = [i for i in input_ids if len(i) >= args.min_len]
    if len(input_ids) < len(prompts):
        logger.warning(
            f'Filtered out {len(prompts) - len(input_ids)} prompts, '
            f'because they are shorter than {args.min_len}.')

    # Decode
    logits = engine.decode(input_ids)

    print(f'logits.shape = {logits.shape}')
    # Save to pth
    print(f'Dumping results to = {args.save_to}')

    torch.save(logits, args.save_to, pickle_protocol=4)

    del engine
